package com.idea.relax.boot.security.sql;


import com.idea.relax.tool.core.StringUtil;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.util.LinkedHashMap;
import java.util.Map;

/**
 * Request wrapper that runs request parameters and header values through {@link SqlFilter}
 * before handing them to application code.
 * <p>
 * Behaviour depends on the configured {@link ProcessStrategy}:
 * with {@code ProcessStrategy.THROW}, a value flagged by {@code SqlFilter.sqlValidate} causes an
 * {@link IllegalArgumentException}; with any other strategy, the value is rewritten via
 * {@code SqlFilter.sqlReplace}.
 * <p>
 * Filtering always operates on copies, so the wrapped request's own parameter data is never modified
 * and repeated calls return consistently filtered results.
 *
 * @author: 沉香
 * @date: 2023/4/5
 * @description: SQL-injection filtering wrapper for HttpServletRequest
 */
public class SqlInjectHttpServletRequestWrapper extends HttpServletRequestWrapper {

	/**
	 * The original, unwrapped request (for special cases where the caller needs to do its own filtering).
	 */
	final HttpServletRequest orgRequest;

	/**
	 * How suspicious values are handled: throw on detection, or rewrite them.
	 */
	private final ProcessStrategy processStrategy;

	/**
	 * SQL filter shared by all wrapper instances.
	 */
	private final static SqlFilter SQL_FILTER = new SqlFilter();

	/**
	 * @param request         the request to wrap
	 * @param processStrategy how values flagged by the filter are handled
	 */
	public SqlInjectHttpServletRequestWrapper(HttpServletRequest request, ProcessStrategy processStrategy) {
		super(request);
		this.orgRequest = request;
		this.processStrategy = processStrategy;
	}

	/**
	 * Returns the named parameter after SQL filtering. Blank or missing values are returned unfiltered.
	 *
	 * @throws IllegalArgumentException in THROW mode, if the value is rejected by the filter
	 */
	@Override
	public String getParameter(String name) {
		String value = super.getParameter(name);
		if (StringUtil.isNotBlank(value)) {
			value = sqlInjectValidate(value);
		}
		return value;
	}

	/**
	 * Returns a filtered copy of the named parameter's values, or {@code null} if there are none.
	 *
	 * @throws IllegalArgumentException in THROW mode, if any value is rejected by the filter
	 */
	@Override
	public String[] getParameterValues(String name) {
		String[] parameters = super.getParameterValues(name);
		if (parameters == null || parameters.length == 0) {
			return null;
		}

		// Filter a copy: the wrapped request may return its internal array, and writing into it would
		// corrupt the underlying request and cause values to be filtered again on every call.
		String[] filtered = parameters.clone();
		for (int i = 0; i < filtered.length; i++) {
			filtered[i] = sqlInjectValidate(filtered[i]);
		}
		return filtered;
	}

	/**
	 * Returns a new map (insertion-ordered) holding filtered copies of every parameter's values.
	 * The wrapped request's own map and arrays are left untouched.
	 *
	 * @throws IllegalArgumentException in THROW mode, if any value is rejected by the filter
	 */
	@Override
	public Map<String, String[]> getParameterMap() {
		Map<String, String[]> parameters = super.getParameterMap();
		Map<String, String[]> map = new LinkedHashMap<>();
		for (Map.Entry<String, String[]> entry : parameters.entrySet()) {
			String[] values = entry.getValue();
			if (values == null) {
				map.put(entry.getKey(), null);
				continue;
			}
			// Copy before filtering for the same reason as in getParameterValues.
			String[] filtered = values.clone();
			for (int i = 0; i < filtered.length; i++) {
				filtered[i] = sqlInjectValidate(filtered[i]);
			}
			map.put(entry.getKey(), filtered);
		}
		return map;
	}

	/**
	 * Returns the named header's value after SQL filtering. Blank or missing values are returned unfiltered.
	 * <p>
	 * The header name is supplied by server-side code, not by the client, so it is looked up as-is;
	 * filtering it could make the lookup target a different header (replace mode) or reject a legitimate
	 * lookup (throw mode).
	 *
	 * @throws IllegalArgumentException in THROW mode, if the value is rejected by the filter
	 */
	@Override
	public String getHeader(String name) {
		String value = super.getHeader(name);
		if (StringUtil.isNotBlank(value)) {
			value = sqlInjectValidate(value);
		}
		return value;
	}

	/**
	 * Applies the configured strategy to a single value.
	 * In THROW mode, the value is returned unchanged unless {@code SQL_FILTER.sqlValidate} returns true,
	 * in which case an exception is thrown. Otherwise, the result of {@code SQL_FILTER.sqlReplace} is returned.
	 *
	 * @throws IllegalArgumentException in THROW mode, when the filter flags the value
	 */
	private String sqlInjectValidate(String value) {
		if (ProcessStrategy.THROW.equals(processStrategy)) {
			if (SQL_FILTER.sqlValidate(value)) {
				throw new IllegalArgumentException("请求参数不合法！");
			}
		} else {
			value = SQL_FILTER.sqlReplace(value);
		}
		return value;
	}

	/**
	 * Returns the original request that this wrapper was constructed with.
	 */
	public HttpServletRequest getOrgRequest() {
		return orgRequest;
	}

	/**
	 * Returns the original request if {@code request} is a {@link SqlInjectHttpServletRequestWrapper};
	 * otherwise returns {@code request} itself. Only the outermost object is checked, so a wrapper nested
	 * inside another wrapper type is not unwrapped.
	 */
	public static HttpServletRequest getOrgRequest(HttpServletRequest request) {
		if (request instanceof SqlInjectHttpServletRequestWrapper) {
			return ((SqlInjectHttpServletRequestWrapper) request).getOrgRequest();
		}

		return request;
	}

}
